// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/ir/transforms/constant_folding_pass.h"

#include <memory>
#include <string>
#include <unordered_map>

// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"

#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/ir/transforms/transform_general_functions.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/ir/pattern_rewrite/pattern_match.h"
#include "paddle/ir/pattern_rewrite/pattern_rewrite_driver.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"

namespace {

// Folds a constant-input operation at compile time: if every input of an op
// is produced by a builtin GetParameterOp, the op is executed once on CPU via
// the new executor and replaced with a GetParameterOp that reads the
// precomputed result.
//
// NOTE(review): suffix_ and scope_ are shared mutable statics, so this
// pattern is not safe to apply from multiple threads concurrently — confirm
// the rewrite driver is single-threaded.
class ConstantFoldingPattern : public ir::RewritePattern {
 public:
  ConstantFoldingPattern(ir::IrContext* context,
                         ir::PatternBenefit benefit = 1,
                         const std::vector<std::string>& generated_names = {})
      : RewritePattern(MatchAnyOpTypeTag(), benefit, context, generated_names) {
  }

  // Matches any op whose operands all come from GetParameterOp, excluding the
  // parameter-handling and fetch ops themselves (folding those would be
  // meaningless or would drop program outputs).
  bool Match(ir::Operation* op) const override {
    // TODO(liuyuanle): Use trait to improve robustness.
    if (op->dyn_cast<ir::GetParameterOp>() ||
        op->dyn_cast<ir::SetParameterOp>() ||
        op->dyn_cast<paddle::dialect::FetchOp>())
      return false;

    // Inputs must come from get parameter op.
    for (uint32_t i = 0; i < op->num_operands(); ++i)
      if (ir::GetDefiningOpForInput(op, i)->dyn_cast<ir::GetParameterOp>() ==
          nullptr)
        return false;
    return true;
  }

  // Executes `op` once in an isolated program and swaps it for a
  // GetParameterOp pointing at the computed tensor.
  void Rewrite(ir::Operation* op,
               ir::PatternRewriter& rewriter) const override {  // NOLINT
    ir::Program* program = op->GetParentProgram();
    // Wrap the single op (plus parameter feeds and a fetch) in a fresh
    // program so it can be lowered and executed in isolation.
    auto temp_program = BuildProgramFromOperation(op);

    // Collect fetch variable names ordered by their "col" attribute; the
    // interpreter returns fetch results in this order.
    std::vector<std::string> fetch_var_names;
    auto block = temp_program->block();
    for (auto it = block->begin(); it != block->end(); ++it) {
      if ((*it)->name() == "pd.fetch") {
        size_t index =
            (*it)->attributes().at("col").dyn_cast<ir::Int32Attribute>().data();

        if (fetch_var_names.size() < index + 1) {
          fetch_var_names.resize(index + 1);
        }

        // The executor exposes fetch results under "<name>@fetch".
        fetch_var_names[index] = (*it)
                                     ->attributes()
                                     .at("name")
                                     .dyn_cast<ir::StrAttribute>()
                                     .AsString() +
                                 "@fetch";
      }
    }

    // Execute program
    paddle::framework::interpreter::ExecutionConfig exe_config;
    // Run directly in the shared scope_ so the input parameter tensors are
    // visible and the result variable outlives the interpreter.
    exe_config.create_local_scope = false;
    paddle::framework::InterpreterCore core(
        phi::CPUPlace{},
        fetch_var_names,
        paddle::dialect::PdOpLowerToKernelPass(temp_program.get()),
        &scope_,
        exe_config);

    paddle::framework::FetchList fetch_list = core.Run({});

    // TODO(liuyuanle): Support multiple output.
    auto out_tensor = PADDLE_GET_CONST(phi::DenseTensor, fetch_list[0]);
    // NOTE(review): the parameter is built from the tensor's raw buffer;
    // assumes ir::Parameter copies the bytes, or that out_tensor stays alive
    // via the scope variable assigned below — confirm ownership semantics.
    std::unique_ptr<ir::Parameter> parameter = std::make_unique<ir::Parameter>(
        reinterpret_cast<void*>(out_tensor.data()),
        out_tensor.numel() * phi::SizeOf(out_tensor.dtype()),
        op->result(0).type());

    // Unique name for the folded result; suffix_ is a process-wide counter.
    std::string param_name =
        "@constant_folding_pass@_" + std::to_string(suffix_++);

    // Store the computed tensor in scope_ under the new parameter name so
    // later lookups (including further rounds of folding) can find it.
    auto* param_var = scope_.Var(param_name);
    auto* param_tensor = param_var->GetMutable<phi::DenseTensor>();
    *param_tensor = out_tensor;
    program->SetParameter(param_name, std::move(parameter));
    // rewriter.SetInsertionPoint(op);
    auto get_parameter_op =
        rewriter.Build<ir::GetParameterOp>(param_name, op->result(0).type());

    // Redirect all uses of the folded op's (single) result, then drop it.
    rewriter.ReplaceAllUsesWith(op->result(0), get_parameter_op->result(0));
    rewriter.EraseOp(op);
  }

 private:
  // Builds a standalone program containing: one GetParameterOp per input of
  // `op`, a clone of `op` itself, and a FetchOp on its first result.
  std::unique_ptr<ir::Program> BuildProgramFromOperation(
      ir::Operation* op) const {
    auto program = std::make_unique<ir::Program>(ir_context());
    ir::Builder builder = ir::Builder(ir_context(), program->block());

    // prepare op inputs
    std::vector<ir::OpResult> op_inputs;
    for (uint32_t i = 0; i < op->num_operands(); i++) {
      PADDLE_ENFORCE_EQ(
          op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
          true,
          phi::errors::InvalidArgument(
              "Op's input must be a dense tensor type."));

      // Copy the source parameter into the temp program so it is
      // self-contained when lowered/executed.
      auto [param_name, param] =
          ir::GetParameterFromValue(op->operand_source(i));
      program->SetParameter(param_name,
                            std::make_unique<ir::Parameter>(*param));

      // The parameter's tensor must already live in scope_ (populated by an
      // earlier stage or a previous fold).
      auto* param_var = scope_.FindVar(param_name);
      PADDLE_ENFORCE_NOT_NULL(
          param_var,
          phi::errors::InvalidArgument("Parameter var not in scope."));

      auto get_parameter_op = builder.Build<ir::GetParameterOp>(
          param_name, op->operand_source(i).type());
      op_inputs.push_back(get_parameter_op->result(0));
    }

    // prepare op outputs
    std::vector<ir::Type> output_types;
    for (uint32_t i = 0; i < op->num_results(); i++) {
      output_types.push_back(op->result(i).type());
    }

    // Re-create the op inside the temp program with identical attributes,
    // result types, and op info.
    auto* temp_op =
        builder.Build(op_inputs, op->attributes(), output_types, op->info());

    // TODO(liuyuanle): Support multiple output.
    // for (uint32_t i = 0; i < op->num_results(); i++) {
    PADDLE_ENFORCE_EQ(
        temp_op->result(0).type().isa<paddle::dialect::DenseTensorType>(),
        true,
        phi::errors::InvalidArgument(
            "Op's output must be a dense tensor type."));

    builder.Build<paddle::dialect::FetchOp>(
        temp_op->result(0), "fetch_" + std::to_string(suffix_++), 0);
    // }

    return program;
  }

 private:
  // Monotonically increasing counter used to generate unique parameter and
  // fetch names across all pattern applications.
  static size_t suffix_;
  // Shared variable scope holding input parameter tensors and folded results
  // between pattern applications.
  static paddle::framework::Scope scope_;
};

size_t ConstantFoldingPattern::suffix_ = 0;
paddle::framework::Scope ConstantFoldingPattern::scope_ = {};

// Pass wrapper that registers ConstantFoldingPattern and applies it greedily
// over the top-level module region.
class ConstantFoldingPass : public ir::Pass {
 public:
  // TODO(liuyuanle): Naming convention for pass.
  ConstantFoldingPass() : ir::Pass("ConstantFoldingPass", 1) {}

  // Builds the frozen pattern set once, before any Run() invocation.
  bool Initialize(ir::IrContext* context) override {
    ir::RewritePatternSet pattern_set(context);
    pattern_set.Add<ConstantFoldingPattern>(context);
    patterns_ = ir::FrozenRewritePatternSet(std::move(pattern_set));
    return true;
  }

  // Repeatedly folds constant ops in the op's first region, walking
  // top-down with a bounded number of fixed-point iterations.
  void Run(ir::Operation* op) override {
    ir::GreedyRewriteConfig config;
    config.use_top_down_traversal = true;
    config.max_iterations = 10;
    ir::ApplyPatternsGreedily(op->region(0), patterns_, config);
  }

  // The pass only targets the top-level builtin.module op, and only when it
  // actually carries a region to rewrite.
  bool CanApplyOn(ir::Operation* op) const override {
    if (op->name() != "builtin.module") {
      return false;
    }
    return op->num_regions() > 0;
  }

 private:
  // Immutable pattern set shared across Run() invocations.
  ir::FrozenRewritePatternSet patterns_;
};

}  // namespace

namespace ir {

// Factory: returns an owning pointer to a freshly constructed
// constant-folding pass instance.
std::unique_ptr<Pass> CreateConstantFoldingPass() {
  auto pass = std::make_unique<ConstantFoldingPass>();
  return pass;
}

}  // namespace ir
